{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple Kernel SHAP\n",
    "\n",
    "This notebook provides a simple brute force version of Kernel SHAP that enumerates the entire $2^M$ sample space. We also compare to the full KernelExplainer implementation. Note that KernelExplainer does a sampling approximation for large values of $M$, but for small values it is exact."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Brute Force Kernel SHAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  reference = [0. 0. 0. 0.]\n",
      "          x = [ 1.62434536 -0.61175641 -0.52817175 -1.07296862]\n",
      "shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]\n",
      " base_value = 10.000000000000002\n",
      "   sum(phi) = 9.55093584213122\n",
      "       f(x) = 9.55093584213122\n"
     ]
    }
   ],
   "source": [
    "import itertools\n",
    "\n",
    "import numpy as np\n",
    "import scipy.special\n",
    "\n",
    "\n",
    "def powerset(iterable):\n",
    "    s = list(iterable)\n",
    "    return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(len(s) + 1))\n",
    "\n",
    "\n",
    "def shapley_kernel(M, s):\n",
    "    if s == 0 or s == M:\n",
    "        return 10000\n",
    "    return (M - 1) / (scipy.special.binom(M, s) * s * (M - s))\n",
    "\n",
    "\n",
    "def f(X):\n",
    "    np.random.seed(0)\n",
    "    beta = np.random.rand(X.shape[-1])\n",
    "    return np.dot(X, beta) + 10\n",
    "\n",
    "\n",
    "def kernel_shap(f, x, reference, M):\n",
    "    X = np.zeros((2**M, M + 1))\n",
    "    X[:, -1] = 1\n",
    "    weights = np.zeros(2**M)\n",
    "    V = np.zeros((2**M, M))\n",
    "    for i in range(2**M):\n",
    "        V[i, :] = reference\n",
    "\n",
    "    ws = {}\n",
    "    for i, s in enumerate(powerset(range(M))):\n",
    "        s = list(s)\n",
    "        V[i, s] = x[s]\n",
    "        X[i, s] = 1\n",
    "        ws[len(s)] = ws.get(len(s), 0) + shapley_kernel(M, len(s))\n",
    "        weights[i] = shapley_kernel(M, len(s))\n",
    "    y = f(V)\n",
    "    wsq = np.sqrt(weights)\n",
    "    result = np.linalg.lstsq(wsq[:, None] * X, wsq * y, rcond=None)[0]\n",
    "    return result\n",
    "\n",
    "\n",
    "M = 4\n",
    "np.random.seed(1)\n",
    "x = np.random.randn(M)\n",
    "reference = np.zeros(M)\n",
    "phi = kernel_shap(f, x, reference, M)\n",
    "base_value = phi[-1]\n",
    "shap_values = phi[:-1]\n",
    "\n",
    "print(\"  reference =\", reference)\n",
    "print(\"          x =\", x)\n",
    "print(\"shap_values =\", shap_values)\n",
    "print(\" base_value =\", base_value)\n",
    "print(\"   sum(phi) =\", np.sum(phi))\n",
    "print(\"       f(x) =\", f(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using KernelExplainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shap_values = [ 0.89146267 -0.43752168 -0.31836259 -0.58464256]\n",
      "base value = 10.0\n"
     ]
    }
   ],
   "source": [
    "import shap\n",
    "\n",
    "explainer = shap.KernelExplainer(f, np.reshape(reference, (1, len(reference))))\n",
    "shap_values = explainer.shap_values(x)\n",
    "print(\"shap_values =\", shap_values)\n",
    "print(\"base value =\", explainer.expected_value)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
